import os
import openai
import json
from prompta.utils.set_api import *
from typing import Iterable


# Module-level OpenAI client, constructed when this module is imported and
# passed to run_gen_program_by_task() by the script entry point below.
# NOTE(review): presumably relies on the star import from
# prompta.utils.set_api to configure API credentials first — confirm.
code_oracle_client = openai.OpenAI()


def gen_program_by_nl_description(client, description: str, alphabet: Iterable[str]) -> str:
    """Ask the chat model for a Python function deciding membership in a DFA.

    Args:
        client: Object exposing ``chat.completions.create(model=..., messages=...)``
            (an ``openai.OpenAI`` client in this file).
        description: Natural-language description of the DFA.
        alphabet: Symbols of the DFA's alphabet. Any iterable is accepted; it is
            materialised to a list so it renders as its symbols in the prompt.

    Returns:
        The ``message.content`` of the first choice of the response, unmodified.
    """
    # Plain string (no placeholders). The worked example must agree with the
    # stated rule: ('a', 'b', 'b') contains one 'a' (odd), so the expected
    # answer is False.
    sys_prompt = """
    You are a helpful assistant that generates a Python function to determine whether a given word is accepted by a described DFA.
    You will receive a natural language description of the DFA and its alphabet. You don't have to create a state transition table.
    You can use the easiest way to implement the function.
    You should import all the packages you used within the function.

    The function you generate should:
    - Take a single input: a word represented as a tuple of symbols (Tuple[str, ...]).
    - Return a boolean: True if the word is accepted by the DFA, False otherwise.

    For example, if the DFA accepts words with an even number of 'a's, and the input is ('a', 'b', 'b'), the function should return False.
    """
    # A one-shot iterator would otherwise be formatted as "<... object at 0x...>".
    symbols = list(alphabet)
    user_prompt = f"""
        Here is the description:
        {description}

        Here is the alphabet:
        {symbols}

        Please generate the function.
    """
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    return response.choices[0].message.content


def run_gen_program_by_task(client):
    """Generate and save a program for every task listed in ``TASK_REGISTRY``.

    For each ``(difficulty, num_tasks)`` entry and each index ``1..num_tasks``,
    the target ``<difficulty><index:02d>.taf`` is loaded and its definition is
    printed. Unless ``<name>.txt`` already exists in the programs directory,
    the model is asked for a program and its reply is written to that file.

    Both directories are relative paths, so they are resolved against the
    current working directory.

    Args:
        client: Chat client forwarded to ``gen_program_by_nl_description``.

    Raises:
        FileNotFoundError: If a task's target file does not exist.
        ValueError: If the model reply carries no text content.
    """
    # Function-scope import: prompta.core.language is only loaded when this
    # runner is actually invoked.
    from prompta.core.language import RuleBasedLanguage, TASK_REGISTRY

    language_dir = '../../pipelines/prompta/test/targets/{}.taf'
    programs_save_dir = '../../pipelines/prompta/test/programs'
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(programs_save_dir, exist_ok=True)

    for difficulty, num_tasks in TASK_REGISTRY.items():
        for index in range(1, num_tasks + 1):
            task_target_name = '{}{:02d}'.format(difficulty, index)
            target_path = language_dir.format(task_target_name)
            program_path = os.path.join(programs_save_dir, '{}.txt'.format(task_target_name))

            # Validate the input before it is used, with a real exception
            # rather than an assert (asserts are stripped under ``python -O``).
            if not os.path.exists(target_path):
                raise FileNotFoundError(f'target file not found: {target_path}')

            language = RuleBasedLanguage(target_path)
            description = language.ctx.definition()
            print(description)
            if os.path.exists(program_path):
                print(f'{task_target_name} already exists')
                continue

            alphabet = list(language.alphabet)
            response = gen_program_by_nl_description(client, description, alphabet)
            # Check the reply BEFORE opening the file: open(..., 'w') creates
            # the file immediately, so failing inside the ``with`` would leave
            # an empty program that later runs skip as "already exists".
            if response is None:
                raise ValueError(f'model returned no content for {task_target_name}')
            # Explicit encoding: generated code may contain non-ASCII characters
            # and the platform default encoding is not guaranteed to be UTF-8.
            with open(program_path, 'w', encoding='utf-8') as f:
                f.write(response)
            print(f'{task_target_name} done')


# Script entry point: when executed directly (not imported), generate programs
# for all registered tasks using the module-level OpenAI client.
if __name__ == "__main__":
    run_gen_program_by_task(code_oracle_client)

